Amazon SageMakerでTensorFlowを使ってIris分類してみた

Amazon SageMakerでTensorFlowを使ってIris分類してみた

Clock Icon2018.09.03

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

はじめに

SageMaker Python SDKを使うことでTensorFlowの学習をAmazon SageMaker上で簡単に行えます。
今回は、AWSが公開しているTensorFlowをAmazon SageMakerで使って分類モデルを学習させるを実際にやってみたので、紹介していきたいと思います。

概要

TensorFlowのDNNClassifierモデルをAmazon SageMakerで学習させて、Iris(あやめ)のデータセットを分類します。

  • Irisのデータセットは、花のがく片と花びらそれぞれの長さと幅のデータに加えて、品種(3種類)が入っています。データ分析のチュートリアルでよく使われるデータセットの一つです。
  • DNNClassifierはTensorFlowのニューラルネットワークを使った分類モデルです。高レベルAPIなので、分類モデルを簡単に使うことができます。

やってみた

準備

ノートブックの作成

SageMakerのノートブックインスタンスを立ち上げて表示されるjupyterのトップページのタブから SageMaker Examples

SageMaker Python SDK

tensorflow_iris_dnn_classifier_using_estimators.ipynb

use
でサンプルのノートブックをコピーして、開きます。
ノートブックインスタンスの作成についてはこちらをご参照ください。

環境変数とロールの確認

学習データ等を保存するS3の場所の指定と、学習やエンドポイントを立ち上げる際に使用するIAMロールの取得を行います。

from sagemaker import get_execution_role

#Bucket location to save your custom code in tar.gz format.
custom_code_upload_location = 's3://<bucket-name>/customcode/tensorflow_iris'

#Bucket location where results of model training are saved.
model_artifacts_location = 's3://<bucket-name>/artifacts'

#IAM execution role that gives SageMaker access to resources in your AWS account.
role = get_execution_role()

学習

entry_point(学習用スクリプト)について

スクリプトファイルiris_dnn_classifier.pyを学習時のentry_pointとして設定します。このスクリプトの中にTensorFlowを使ってモデルの定義などを行います。

entry_pointとして設定するスクリプトファイルには次の関数が定義されている必要があります。

estimator_fn
学習するモデルを定義します。今回の場合は中間層が三層のtf.estimator.DNNClassifierのモデルを使用します。
train_input_fn
学習データを読み込むための処理を定義します。学習データはs3にアップされているsagemaker-sample-data-ap-northeast-1/tensorflow/iris_training.csvを使います。
eval_input_fn
評価データを読み込むための処理を定義します。検証用データはs3にアップされているsagemaker-sample-data-ap-northeast-1/tensorflow/iris_test.csvを使います。
serving_input_fn
推論時にモデルに入力されるデータの形式を定義したものです。この関数は必須ではありませんが、Amazon SageMakerで学習させたモデルを展開する場合は必要です *1

※詳細についてはSDKのGitHubでの説明をご参照いただければと思います。例を交えた説明があり、分かりやすいです。

今回、entry_pointとして使用するスクリプトは以下の通りです。

import numpy as np
import os
import tensorflow as tf

INPUT_TENSOR_NAME = 'inputs'

# より良いパフォーマンスを得るためにMKLを無効化する(参考: https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow#training-with-mkl-dnn-disabled)
os.environ['TF_DISABLE_MKL'] = '1'
os.environ['TF_DISABLE_POOL_ALLOCATOR'] = '1'

# モデルの設定
def estimator_fn(run_config, params):
    feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]
    return tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                      hidden_units=[10, 20, 10],
                                      n_classes=3,
                                      config=run_config)

# 推論時の入力データ形式を定義
def serving_input_fn(params):
    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()

# 学習用データ入力関数を返す
def train_input_fn(training_dir, params):
    """Returns input function that would feed the model during training"""
    return _generate_input_fn(training_dir, 'iris_training.csv')

# 評価用データ入力関数を返す
def eval_input_fn(training_dir, params):
    """Returns input function that would feed the model during evaluation"""
    return _generate_input_fn(training_dir, 'iris_test.csv')

# csv形式のデータを読み込んで、モデルに入力データを渡すための入力関数を作成する
def _generate_input_fn(training_dir, training_filename):
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=os.path.join(training_dir, training_filename),
        target_dtype=np.int,
        features_dtype=np.float32)

    return tf.estimator.inputs.numpy_input_fn(
        x={INPUT_TENSOR_NAME: np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True)()

パラメータの設定

学習に向けて、パラメータを設定します。

from sagemaker.tensorflow import TensorFlow

iris_estimator = TensorFlow(entry_point='iris_dnn_classifier.py', # 学習用スクリプトファイルを指定(スクリプトファイルが複数の場合はsource_dirで指定出来ます)
                            role=role, # 学習やエンドポイントの作成に使用するIAMロール名
                            framework_version='1.9', # 使用するTensorFlowのバージョン
                            output_path=model_artifacts_location, # モデルアーティファクトの出力先
                            code_location=custom_code_upload_location, # スクリプトを保存する場所
                            train_instance_count=1, # 学習時に使用するインスタンス数
                            train_instance_type='ml.c4.xlarge', # 学習時に使用するインスタンスタイプ
                            training_steps=1000, # 学習のステップ数
                            evaluation_steps=100 ) # 評価のステップ数

学習の実行

設定した内容に基づいて学習を実行します。
SageMakerがインスタンスを立ち上げて、学習処理を実行し、学習終了後にインスタンスを自動的に終了させます。
学習状態は随時ログが出て来るので、追うことができます。

%%time
import boto3

# use the region-specific sample data bucket
region = boto3.Session().region_name
train_data_location = 's3://sagemaker-sample-data-{}/tensorflow/iris'.format(region)

iris_estimator.fit(train_data_location)

モデルの展開

エンドポイントを作成し、先ほど学習させたモデルをエンドポイントに展開します。

%%time
iris_predictor = iris_estimator.deploy(initial_instance_count=1,
                                       instance_type='ml.m4.xlarge')

推論

実際にデータをエンドポイントに投げて、予測してみます。正しいラベル値は1となるデータです。

iris_predictor.predict([6.4, 3.2, 4.5, 1.5]) #expected label to be 1

予測結果のレスポンスには、resultというキーで各ラベルに対する確率が入っています。
予測結果を見てみると、ラベル値1の確率が99%を超えています。一つのデータではありますが、正しく分類が出来ました。

エンドポイントの削除

余計な費用がかからないように、エンドポイントを削除します。

import sagemaker

sagemaker.Session().delete_endpoint(iris_predictor.endpoint)

さいごに

今回はAmazon SageMaker上でのTensorFlowを使ったIrisの分類方法について紹介しました。SageMaker Python SDKを利用することで、TensorFlowを使った学習が簡単に行うことができました。

また、MXNet、Chainer、PyTorchに関しても、今回紹介したTensorFlow同様にAmazon SageMaker上での学習が可能です *2。それらについても追って試してみたいと思います。

最後までお読みいただき、ありがとうございましたー!

参考

脚注

  1. この関数は推論のために必要なものですが、推論時ではなく学習の最後に呼ばれるようです。
  2. 2018年9月2日時点での情報です。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.